import argparse
import os
import numpy as np
import torch
import triton
import triton.language as tl
import triton.language.extra.libdevice as tldevice

# Select the math primitives used by kernels in this module. Setting
# FLA_USE_FAST_OPS=1 swaps in the libdevice fast approximations; the default
# is the standard (accurate) Triton intrinsics.
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
    exp, exp2, log, log2 = (
        tldevice.fast_expf,
        tldevice.exp2,
        tldevice.fast_logf,
        tldevice.fast_log2f,
    )
else:
    exp, exp2, log, log2 = (tl.exp, tl.math.exp2, tl.log, tl.log2)

@triton.heuristics({
    'HAS_WEIGHT': lambda args: args['weight'] is not None,
    'HAS_BIAS': lambda args: args['bias'] is not None,
    'HAS_RESIDUAL': lambda args: args['residual'] is not None,
    'USE_INITIAL_STATE': lambda args: args['initial_state'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.jit
def causal_conv1d_fwd_kernel(
    x,               # input tokens, addressed as a flat (tokens, D) buffer with stride (D, 1)
    y,               # output, same layout as x
    weight,          # optional depthwise filter, addressed as (D, W)
    bias,            # optional per-channel bias, (D,)
    residual,        # optional residual (same layout as x), added AFTER the activation
    cu_seqlens,      # optional (N+1,) prefix sums of sequence lengths; presence enables varlen mode
    initial_state,   # optional per-sequence context for t < 0, addressed as (N, D, W)
    chunk_indices,   # varlen only: flat (seq_idx, chunk_idx) int pairs, one per time chunk
    B,               # batch size (only meaningful in fixed-length mode)
    T,               # sequence length; in varlen mode it is recomputed per sequence below
    D: tl.constexpr,    # feature (channel) dimension
    W: tl.constexpr,    # filter width
    BT: tl.constexpr,   # time tile size
    BW: tl.constexpr,   # power-of-two cover of W (tl.arange needs a power-of-two extent)
    BD: tl.constexpr,   # feature tile size
    NB: tl.constexpr,   # NOTE(review): unused in this kernel body — presumably consumed by callers/autotune; confirm
    ACTIVATION: tl.constexpr,   # 'swish'/'silu' applies x*sigmoid(x); any other value is identity
    HAS_WEIGHT: tl.constexpr,        # derived by @triton.heuristics from `weight is not None`
    HAS_BIAS: tl.constexpr,          # derived from `bias is not None`
    HAS_RESIDUAL: tl.constexpr,      # derived from `residual is not None`
    USE_INITIAL_STATE: tl.constexpr, # derived from `initial_state is not None`
    IS_VARLEN: tl.constexpr,         # derived from `cu_seqlens is not None`
):
    """Forward pass of a causal depthwise 1-D convolution.

    Each program writes one (BT, BD) tile of y:

        y[t, d] = act( sum_{k=0..W-1} weight[d, k] * x[t - (W-1) + k, d] + bias[d] ) + residual[t, d]

    Positions before the start of a sequence read from ``initial_state`` when
    it is provided and contribute zero otherwise, so the convolution never
    looks at future tokens.
    """
    i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    if IS_VARLEN:
        # Varlen mode: program_id(1) enumerates chunks globally; chunk_indices
        # maps it to (sequence index i_n, chunk index i_t within that sequence).
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n), tl.load(cu_seqlens + i_n + 1)
        T = eos - bos
    else:
        # Fixed-length mode: one sequence per program_id(2).
        i_n = i_b
        bos, eos = i_b * T, i_b * T + T

    o_d = i_d * BD + tl.arange(0, BD)  # feature offsets covered by this tile
    o_w = tl.arange(0, BW) + W - BW    # tap offsets; padding entries fall below 0 when BW > W
    m_d = o_d < D
    m_w = o_w >= 0                     # masks out the padding taps

    if HAS_WEIGHT:
        # [BD, BW]
        b_w = tl.load(weight + o_d[:, None] * W + o_w, mask=m_d[:, None] & m_w, other=0).to(tl.float32)

    b_y = tl.zeros((BT, BD), dtype=tl.float32)
    if not USE_INITIAL_STATE:
        # No pre-sequence context: block-pointer loads with boundary_check
        # zero-fill anything outside [0, T), which is exactly causal padding.
        for i_w in tl.static_range(-W + 1, 1):
            p_yi = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
            # [BT, BD]
            b_yi = tl.load(p_yi, boundary_check=(0, 1)).to(tl.float32)
            if HAS_WEIGHT:
                # Pick out the single filter tap matching this shift.
                b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1)
            b_y += b_yi
    elif i_t * BT >= W:
        # This chunk starts at least W steps into the sequence, so no load can
        # reach t < 0 and initial_state is irrelevant — same fast path as above.
        # to make Triton compiler happy, we need to copy codes
        for i_w in tl.static_range(-W + 1, 1):
            p_yi = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
            # [BT, BD]
            b_yi = tl.load(p_yi, boundary_check=(0, 1)).to(tl.float32)
            if HAS_WEIGHT:
                b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1)
            b_y += b_yi
    else:
        # Boundary chunk: some positions fall at t < 0 and must be served from
        # initial_state[i_n, d, t + W]; use element-wise masked loads instead
        # of block pointers.
        o_t = i_t * BT + tl.arange(0, BT)
        for i_w in tl.static_range(-W + 1, 1):
            o_x = o_t + i_w
            m_x = ((o_x >= 0) & (o_x < T))[:, None] & m_d   # in-sequence positions
            m_c = ((o_x + W >= 0) & (o_x < 0))[:, None] & m_d  # positions served by initial_state

            b_yi = tl.load(x + bos * D + o_x[:, None] * D + o_d, mask=m_x, other=0).to(tl.float32)

            b_yi += tl.load(initial_state + i_n * D*W + o_d * W + (o_x + W)[:, None], mask=m_c, other=0).to(tl.float32)

            if HAS_WEIGHT:
                b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1)
            b_y += b_yi

    if HAS_BIAS:
        b_y += tl.load(bias + o_d, mask=m_d).to(tl.float32)

    # SiLU/Swish activation; all other ACTIVATION values leave b_y untouched.
    if ACTIVATION == 'swish' or ACTIVATION == 'silu':
        b_y = b_y * tl.sigmoid(b_y)

    if HAS_RESIDUAL:
        # Residual is added after the activation.
        p_residual = tl.make_block_ptr(residual + bos * D, (T, D), (D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        b_residual = tl.load(p_residual, boundary_check=(0, 1))
        b_y += b_residual

    # Downcast to y's element type with round-to-nearest-even and store.
    p_y = tl.make_block_ptr(y + bos * D, (T, D), (D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
    tl.store(p_y, tl.cast(b_y, dtype=p_y.dtype.element_ty, fp_downcast_rounding='rtne'), boundary_check=(0, 1))

def test_causal_conv1d_fwd_kernel(output_file, device='npu'):
    """Smoke-test `causal_conv1d_fwd_kernel` in varlen mode and dump `y`.

    Random inputs are generated with a fixed seed, the kernel is launched with
    weight/bias/residual/initial_state enabled, and the resulting `y` tensor is
    written to `output_file` as text via `np.savetxt`.

    Args:
        output_file: path the flattened `y` tensor is saved to.
        device: torch device string the tensors live on (default: 'npu').
    """
    # Fixed seed so the dumped output is reproducible.
    torch.manual_seed(42)

    # Problem sizes.
    B = 2    # number of sequences
    T = 64   # total number of tokens covered by cu_seqlens
    D = 32   # feature (channel) dimension
    W = 3    # convolution filter width
    # Tile sizes. tl.arange requires a power-of-two extent, so BW must be the
    # next power of two >= W (the kernel masks the out-of-range taps itself);
    # BW = W = 3 would fail to compile.
    BT = 16
    BW = triton.next_power_of_2(W)
    BD = 8
    NB = 4

    dtype = torch.float32  # matches the kernel's float32 accumulation

    # In varlen mode the kernel addresses x/y/residual as flat (tokens, D)
    # buffers, so only the first cu_seqlens[-1] rows are actually used.
    x = torch.randn(B, T, D, dtype=dtype, device=device)
    y = torch.randn(B, T, D, dtype=dtype, device=device)
    weight = torch.randn(D, W, dtype=dtype, device=device)
    bias = torch.randn(D, dtype=dtype, device=device)
    residual = torch.randn(B, T, D, dtype=dtype, device=device)
    # One (D, W) pre-sequence context per sequence, laid out as the kernel
    # indexes it: initial_state[n, d, w].
    initial_state = torch.randn(B, D, W, dtype=dtype, device=device)

    # cu_seqlens must be a non-decreasing prefix sum of sequence lengths
    # starting at 0; split the first T tokens into B random sequences.
    inner = torch.sort(torch.randint(0, T + 1, (B - 1,))).values
    cu_seqlens = torch.cat([
        torch.zeros(1, dtype=torch.int64),
        inner.to(torch.int64),
        torch.full((1,), T, dtype=torch.int64),
    ]).to(torch.int32).to(device)

    # chunk_indices holds one (sequence index, chunk index within sequence)
    # pair per BT-sized time chunk; program_id(1) dereferences these pairs.
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    pairs = [(n, c) for n, ln in enumerate(seqlens) for c in range(triton.cdiv(ln, BT))]
    chunk_indices = torch.tensor(pairs, dtype=torch.int32, device=device).flatten()

    # Varlen grid: one program per (feature tile, chunk). program_id(2) is
    # unused when IS_VARLEN is true, so the third dimension is 1.
    grid = (triton.cdiv(D, BD), len(pairs), 1)

    # HAS_WEIGHT / HAS_BIAS / HAS_RESIDUAL / USE_INITIAL_STATE / IS_VARLEN are
    # derived by @triton.heuristics from which arguments are non-None, so they
    # must not be passed explicitly.
    causal_conv1d_fwd_kernel[grid](
        x, y, weight, bias, residual, cu_seqlens, initial_state, chunk_indices,
        B=B, T=T, D=D, W=W, BT=BT, BW=BW, BD=BD, NB=NB,
        ACTIVATION='silu',
    )

    # Dump the (possibly partially-overwritten) output as a 2-D text table.
    y_numpy = y.cpu().detach().numpy()
    np.savetxt(output_file, y_numpy.reshape(-1, y_numpy.shape[-1]))

if __name__ == "__main__":
    # CLI entry point: run the smoke test and write results to --output.
    cli = argparse.ArgumentParser(description='Test Causal Conv1D fwd Kernel')
    cli.add_argument(
        '--output',
        type=str,
        default='default_output.txt',
        help='Output file name (default: default_output.txt)',
    )
    opts = cli.parse_args()
    test_causal_conv1d_fwd_kernel(opts.output)